/*
    Copyright 2010-2011 Patrik Jonsson, sunrise@familjenjonsson.org

    This file is part of Sunrise.

    Sunrise is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 3 of the License, or
    (at your option) any later version.

    Sunrise is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with Sunrise.  If not, see <http://www.gnu.org/licenses/>.

*/

/** \file

    Definitions of the functions used to interface with Arepo. */

// Include order is picky here. Something in the Arepo headers appears
// to clobber the blitz headers, so include all non-Arepo headers
// first.

#include "mcrx-arepo.h"
#include "blitz/numinquire.h"
#include "mcrx-debug.h"
#include "misc.h"
#include "units.h"
#include "constants.h"
#include "mono_poly_abstract.h"
#include "vecops.h"
#include <hdf5.h>
#include <boost/mpi/nonblocking.hpp>
#include <boost/mpi/communicator.hpp>
#include <boost/mpi/skeleton_and_content.hpp>
#include "mpi_util.h"
#include "tbb/tick_count.h"

// KLUDGE: we don't need these two here but they are included by arepo
// and they won't compile in C-linkage mode
#include "mpi.h"
#include "gmp.h"
extern "C" {
#include "arepoconfig.h"
#include "proto.h"
#include "allvars.h"
}

// These are declared here because they expose the Arepo data structures.
namespace arepo {
  /// Helper class, defined further down in this file, that initializes
  /// MPI on request and shuts down HDF5/MPI from its destructor.
  class bookkeeping;

  /** A pointer to the main Arepo tessellation object. It is pointed
      at the Arepo global Mesh by load_snapshot(). */
  tessellation* T;

  /// NOTE(review): only declared here; no definition is visible in
  /// this file -- confirm whether it is still needed.
  int find_next_tetra(const mcrx::ray_base& r, int t, point& ppstart, 
		      point& ppend, point& ppexit, 
		      int& next_tetra, int& previous_tetra);
  /// Writes the position p into pptr, dividing by lcon. If si is true
  /// (the default), set_integers_for_pointer() is also called on pptr.
  void set_point(const mcrx::vec3d& p, point& pptr, bool si=true);
  /// Returns the coordinates stored in pptr multiplied by lcon.
  mcrx::vec3d get_point(const point& pptr);
};

using namespace std;
using mcrx::vec3d;
using mcrx::vec3i;
using mcrx::T_float;

using arepo::arepo_units;
using arepo::lcon;

/** The units that the Arepo routines return their quantities
    as. (This is the Arepo quantities multiplied by the appropriate
    *con conversion factors below.) The "length", "time" and "mass"
    entries are filled in by load_snapshot(). */
mcrx::T_unit_map arepo::arepo_units;
/// Length conversion factor: an Arepo-internal length multiplied by
/// lcon is a length in arepo_units["length"]. Set by load_snapshot().
T_float arepo::lcon;
/// Mass conversion factor, analogous to lcon. Set by load_snapshot().
T_float arepo::mcon;
/// Time conversion factor, analogous to lcon. Set by load_snapshot().
T_float arepo::tcon;

/** This vector maps dp_idx to the dp_idx of the primary cell. */
//std::vector<int> arepo::primary_cells;
/** This vector stores the starting midpoint and the number of faces
    for the cell, indexed by dp_idx. */
//std::vector<std::pair<int, int> > arepo::midpoint_idx;
/** This vector stores the midpoints between the mesh points, indexed
    by midpoint_idx. */
//std::vector<mcrx::vec3d> arepo::midpoints;
/** This vector stores the dp index of the cell opposite to the
    midpoint. */
//std::vector<int> arepo::opposite_points;

/// Lower corner of the Arepo box; load_snapshot() sets it to 0.
mcrx::vec3d arepo::arepomin;
/// Upper corner of the Arepo box; load_snapshot() sets it to
/// boxSize*lcon, i.e. the box size converted with the length factor.
mcrx::vec3d arepo::arepomax;

/** Takes care of MPI start-up and of HDF5/MPI shutdown. MPI is
    initialized through the static init() function; the cleanup is
    done by the destructor of the (single, global) instance, and only
    if init() was the one that actually initialized MPI. */
class arepo::bookkeeping {
private:
  /// Flag determining whether we need to do cleanup in destructor. It
  /// is only set when init() itself called MPI_Init.
  static bool inited;

public:
  /** Closes the HDF5 library and finalizes MPI, but only if this
      class initialized MPI in the first place. */
  ~bookkeeping() {
    if(!inited)
      return;

    // NOTE(review): the HDF5 documentation asks for H5dont_atexit to
    // be called before any other HDF5 call; the original call order
    // is kept here -- confirm that it has the intended effect.
    H5close();
    H5dont_atexit();

    // Finalizing MPI twice is an error, so make sure nobody else has
    // already done it before we do.
    int mpi_finalized = 0;
    MPI_Finalized(&mpi_finalized);
    if(!mpi_finalized)
      MPI_Finalize();

    inited=false;
  }

  /** Initializes MPI unless it has already been initialized, either
      by an earlier call or by someone else. In the latter case we do
      not take responsibility for finalizing it.
      \param pargc pointer to argc, passed on to MPI_Init
      \param pargv pointer to argv, passed on to MPI_Init */
  static void init(int* pargc, char*** pargv) {
    if(inited)
      return;

    int mpi_inited = 0;
    MPI_Initialized(&mpi_inited);
    if(!mpi_inited) {
      // Initialize MPI
      MPI_Init(pargc, pargv);
      inited=true;
    }
  }
};

/// global bookkeeping instance. Because it is a global object, its
/// destructor (which does the HDF5/MPI cleanup) runs at program exit.
arepo::bookkeeping bookkeeper;
/// No cleanup is needed until bookkeeping::init() has initialized MPI.
bool arepo::bookkeeping::inited=false;

/** Arepo initialization entry point. All this does is to forward the
    command line arguments to the bookkeeping class, which initializes
    MPI if that has not already been done. */
void arepo::init(int* pargc, char*** pargv)
{
  arepo::bookkeeping::init(pargc, pargv);
}


/** Stores the position p in the Arepo point pptr. The components are
    divided by lcon to convert them to internal Arepo units. If si is
    true, the integer coordinates of the point are also updated by
    calling set_integers_for_pointer. */
void arepo::set_point(const mcrx::vec3d& p, point& pptr, bool si)
{
  pptr.x = p[0]/lcon;
  pptr.y = p[1]/lcon;
  pptr.z = p[2]/lcon;

  if(!si)
    return;

  set_integers_for_pointer(&pptr);
}

/** Returns the position of the Arepo point pptr as a vec3d. The
    coordinates are multiplied by lcon to convert them from internal
    Arepo units. */
mcrx::vec3d arepo::get_point(const point& pptr)
{
  const mcrx::vec3d arepo_pos(pptr.x, pptr.y, pptr.z);
  return arepo_pos*lcon;
}

/** Returns the position of entry p in the Arepo DP array as a vec3d,
    multiplied by lcon to convert from internal Arepo units. */
mcrx::vec3d arepo::get_point(int p)
{
  // the point overload does the actual unit conversion
  return get_point(T->DP[p]);
}

/** Returns true if the mesh point DP[dp_idx] is a primary cell. */
bool arepo::primary_cell(int dp_idx) {
  return Mesh.DP[dp_idx].task == ThisTask && 
    Mesh.DP[dp_idx].index>=0 && Mesh.DP[dp_idx].index<N_gas;
}

/** Returns the entry into the SphP array for the specified DP mesh
    point. For primary cells, this is .index, but for local ghosts it
    is .index - N_gas. If the point is on another task, return -1. */
int arepo::sph_index(int dp_idx) {
  if (Mesh.DP[dp_idx].task == ThisTask) {
    const int sph_idx = Mesh.DP[dp_idx].index;
    if(sph_idx<N_gas)
      return sph_idx;
    else
      // local ghost
      return sph_idx - N_gas;
  }
  
  return -1;
}


/** Wraps a position that is outside of the box around the periodic
    boundary. Each coordinate is shifted by at most one box length. If
    PERIODIC is not defined, the position is returned unchanged. Note
    that this works on coordinates in AREPO units. */
vec3d arepo::periodic_wrap_point(const vec3d& pos)
{
  vec3d wrapped(pos);

#ifdef PERIODIC
  const double box[3] = { boxSize_X, boxSize_Y, boxSize_Z };

  for(int axis=0; axis<3; ++axis) {
    // the two tests are done in sequence, just like for a single axis
    if(wrapped[axis] > box[axis])
      wrapped[axis] -= box[axis];
    if(wrapped[axis] < 0)
      wrapped[axis] += box[axis];
  }
#endif
  return wrapped;
}


/** Returns pos, shifted by a box length in every coordinate where it
    is more than a half-box away from ref, so that the two end up less
    than a half-box apart. This transforms between local ghost cells
    and primary cells. If PERIODIC is not defined, pos is returned
    unchanged. Note that this works on coordinates in AREPO units. */
vec3d arepo::periodic_wrap_point(const vec3d& pos, const vec3d& ref)
{
  vec3d wrapped(pos);
  const vec3d offset = pos-ref;

#ifdef PERIODIC
  const double box[3] = { boxSize_X, boxSize_Y, boxSize_Z };
  const double half[3] = { boxHalf_X, boxHalf_Y, boxHalf_Z };

  for(int axis=0; axis<3; ++axis) {
    if(offset[axis] > half[axis])
      wrapped[axis] -= box[axis];
    if(offset[axis] < -half[axis])
      wrapped[axis] += box[axis];
  }
#endif
  return wrapped;
}


/** Calculates the separation p1-p2. With PERIODIC defined, any
    component that is larger than a half-box in magnitude is shifted
    by a box length across the periodic boundary, after which it is
    asserted to be within a half-box. Note that this works on
    coordinates in AREPO units. */
vec3d arepo::get_periodic_delta(const vec3d& p1, const vec3d& p2)
{
  vec3d separation (p1-p2);

#ifdef PERIODIC
  const double box[3] = { boxSize_X, boxSize_Y, boxSize_Z };
  const double half[3] = { boxHalf_X, boxHalf_Y, boxHalf_Z };

  for(int axis=0; axis<3; ++axis) {
    if(separation[axis] > half[axis])
      separation[axis] -= box[axis];
    if(separation[axis] < -half[axis])
      separation[axis] += box[axis];
    assert(abs(separation[axis])<=half[axis]);
  }
#endif
  return separation;
}



/** Tells Arepo to load a snapshot and sets up the data that the rest
    of the interface relies on. MPI must already have been
    initialized (see arepo::init), since the Arepo task variables are
    set from the MPI communicator here. After Arepo has loaded the
    file, the circumsphere centers of the tessellation are computed,
    the unit conversion factors and the box extent are set, and the
    density gradients are limited so that the density is non-negative
    everywhere.

    \param snapfn name of the snapshot file. A possible suffix is
    stripped, since Arepo decides on that itself.
    \param parfile name of the Arepo parameter file.
    \param units unit map from which the "length", "time" and "mass"
    units that quantities should be converted to are taken. */
void arepo::load_snapshot(const std::string& snapfn, 
			  const std::string& parfile, 
			  const mcrx::T_unit_map& units)
{
  cout << "Calling Arepo to load snapshot " << snapfn << endl;

  // Init Arepo MPI variables
  boost::mpi::communicator world;
  ThisTask = world.rank();
  NTask = world.size();
  // PTask is the smallest integer for which 2^PTask >= NTask
  for(PTask = 0; NTask > (1 << PTask); PTask++);

  T=&Mesh;

  // turn off arepo dumping eos and cooling curve
  WriteMiscFiles = 0;

  // NOTE(review): this is an unbounded copy into an Arepo-owned
  // buffer whose size is not visible from here. A length check
  // against the declared size of ParameterFile should be added.
  strcpy(ParameterFile,parfile.c_str());
  RestartFlag = SUNRISE;
  RestartSnapNum = -1;
  begrun1();

  // remove possible suffix from snap name since arepo decides this
  // based on the format specified in the parfile. Arepo wants a
  // mutable C string, so we copy the name to a buffer that is sized
  // to fit (a fixed-size array would overflow for long file names).
  const std::string snap_base(strip_suffix(snapfn).c_str());
  std::vector<char> snap_name(snap_base.begin(), snap_base.end());
  snap_name.push_back('\0');
  
  // load file, but only use gas particles
  read_ic(&snap_name[0], 0x01);
  int status=::init();
  assert(status==SUNRISE);
  (void)status; // only used by the assert
  cout << "Arepo returned successfully\n  Mesh contains " << T->Ndp
	    << " Delaunay points, " << N_gas << " hydro cells\n";

  // This was done in make_3d_voronoi_projected_image
  T->DTC = reinterpret_cast<tetra_center*>(mymalloc_movable(&T->DTC, "DTC", 
							    T->MaxNdt * sizeof(tetra_center)));
  T->DTF = reinterpret_cast<char*>(mymalloc_movable(&T->DTF, "DTF",
						    T->MaxNdt * sizeof(char)));

  // clear the tetra flags before computing the circumspheres
  for(int i = 0; i < T->Ndt; i++)
    T->DTF[i] = 0;

  compute_circumcircles(T);

  // init units and define conversion factors
  // we need to include h-inverse here
  arepo_units["length"] = units.get("length");
  arepo_units["time"] = units.get("time");
  arepo_units["mass"] = units.get("mass");
  
  const T_float hinv=1./All.HubbleParam;
  lcon = units::convert("cm",arepo_units["length"])*All.UnitLength_in_cm*hinv;
  mcon = units::convert("g",arepo_units["mass"])*All.UnitMass_in_g*hinv;
  tcon = units::convert("s",arepo_units["time"])*All.UnitTime_in_s*hinv;

  arepomin=0;
  arepomax=boxSize*lcon;

  // build the streamlined data structure for Voronoi planes.
  //calculate_midpoints();

  // Apply additional gradient limit to ensure the density field is
  // non-negative everywhere.
  ensure_positive_densities();
}


/** Finds the mesh point nearest to p by simply testing the distance
    to every single mesh point. This is only used for debugging, as it
    does not rely on the (possibly broken) mesh connectivity
    graph. Returns the index into the DP array of the nearest
    point. */
int arepo::locate_brute(const mcrx::vec3d& p)
{
  // start out with point 0 as the best candidate
  int nearest=0;
  vec3d offset(p-arepo::get_point(0));
  double nearest_dist2 = dot(offset, offset);

  const int n_points = arepo::T->Ndp;
  for(int candidate=1; candidate<n_points; ++candidate) {
    offset = p-arepo::get_point(candidate);
    const double candidate_dist2 = dot(offset, offset);

    // only strictly closer points replace the current best
    if(!(candidate_dist2 < nearest_dist2))
      continue;

    nearest_dist2 = candidate_dist2;
    nearest = candidate;
  }

  DEBUG(2,cout << "\tFound " << nearest << endl;);
  return nearest;
}


/** Sets up the midpoints array which stores the midpoints of the face
    planes in a compact way so we don't have to chase a bunch of
    pointers all over the place to do the intersection tests with the
    face planes when finding the voronoi neighbors. The connections
    stored in our data structure differs from that in the Arepo DC
    array in that we map connections between distinct mesh points, not
    just primary cells. */
/*
void arepo::calculate_midpoints()
{
  cout << "Setting up Voronoi cell connectivity structures\n";
  const point* const DP = T->DP;

  // The first thing we must do is to set up a temporary multimap of
  // sphp_idx->dp_idx, so we can identify all local ghosts associated
  // with a sph entry.
  cout << "\tReverse-mapping SPH to DP entries to find local ghosts" << endl;
  typedef multimap<int, int> mmint;
  mmint sphp_dp_map;
  for(int dp_idx=0; dp_idx<T->Ndp; ++dp_idx) {
    const int sphp_idx=sph_index(dp_idx);
    sphp_dp_map.insert(make_pair(sphp_idx, dp_idx));
  }

  // Now that we have the sphp->dp mapping, we can set up the mapping
  // from dp indices to the dp indices of the primary cells.
  for(int dp_idx=0; dp_idx<T->Ndp; ++dp_idx) {
    const int sphp_idx=sph_index(dp_idx);
    if(sphp_idx<0)
      // cell has no hydro quantities, so this doesn't make sense
      primary_cells.push_back(-1);

    // loop over all dp indices that share this sphp cell
    pair<mmint::iterator, mmint::iterator> dp_indices
      (sphp_dp_map.equal_range(sphp_idx));
    assert(dp_indices.second != dp_indices.first);
    while( dp_indices.first != dp_indices.second) {
      // when we find the primary cell, save it
      if (primary_cell(dp_indices.first->second)) {
	primary_cells.push_back(dp_indices.first->second);
	break;
      }
      ++dp_indices.first;
    }
  }
  assert(primary_cells.size()==T->Ndp);

  // use VF array to generate connections. this is just a matter of
  // reorganizing it to index by point.
  cout << "\tBuilding connection multimap" << endl;
  mmint connections;
  for(int i=0;i<T->Nvf; ++i) {
    connections.insert(make_pair(T->VF[i].p1, T->VF[i].p2));
    connections.insert(make_pair(T->VF[i].p2, T->VF[i].p1));
  }
  cout << "\tAssociating connections with cells" << endl;
  for(int dp_idx=0; dp_idx<T->Ndp; ++dp_idx) {
    // check if it's a valid cell
    if(sph_index(dp_idx)<0)
      continue;

    const vec3d cell_p(get_point(dp_idx));

    // Find the connections connecting to this cell
    pair<mmint::iterator, mmint::iterator> dp_neighbors
      (connections.equal_range(dp_idx));
    assert(dp_neighbors.first != dp_neighbors.second);

    for( ;dp_neighbors.first != dp_neighbors.second; ++dp_neighbors.first) {
      const int dp_neighbor = dp_neighbors.first->second;
      if(sph_index(dp_neighbor)<0)
	continue;

      const vec3d midpoint(0.5*(cell_p + get_point(dp_neighbor)));

      // now add the connection to the data structure
      midpoints.push_back(midpoint);
      opposite_points.push_back(dp_neighbor);
    }

    // now that we've done all the connections for DP cell dp_idx,
    // update the number and position into the midpoint index vector.
    const int start_pos = 
      midpoint_idx.empty() ? 0 :
      (midpoint_idx.back().first + midpoint_idx.back().second);
    midpoint_idx.push_back(make_pair(start_pos, midpoints.size()-start_pos));
  }
  

  cout << midpoints.size() << " connections generated" << endl;
}
*/

/** This function is closely modeled after Arepo's
    treefind_overlapping_domains. It walks the top-level part of the
    Arepo tree and marks the tasks whose pseudo-particle nodes lie
    within the search region around p.

    \param p the search position. (locate() passes this in Arepo
    units.)
    \param mindist the search radius around p.
    \param taskmask is reset to NTask entries, all false, and then the
    entry of every task that has an overlapping pseudo particle is set
    to true.
    \return true if at least one task was marked in taskmask. */
bool arepo::treefind_overlapping_domains(const vec3d& p, MyFloat mindist, 
					 vector<bool>& taskmask)
{
  // the walk starts at node index All.MaxPart; indices below that
  // denote single particles
  int no = All.MaxPart;
  struct NODE *current;
  bool any_overlapping=false;
#ifdef PERIODIC
  // NOTE(review): xtmp is not referenced directly in this function;
  // presumably it is needed by the NGB_PERIODIC_LONG_* macros --
  // verify against the Arepo headers.
  MyDouble xtmp;
#endif

  taskmask.assign(NTask, false);

  while(no >= 0) {
    if(no < All.MaxPart) {
      // we should never get to single particles
      assert(0);
    }
    else {
      if(no >= All.MaxPart + MaxNodes)	/* pseudo particle */
	{
	  // a pseudo particle stands in for a part of the tree on
	  // another task; record that task and move on
	  const int task = DomainTask[no - (All.MaxPart + MaxNodes)];
	  assert(task>=0);
	  assert(task<NTask);
	  assert(task!=ThisTask);
	  taskmask[task]=true;
	  any_overlapping=true;

	  no = Nextnode[no - MaxNodes];
	  continue;
	}
	
      current = &Nodes[no];
      // only top-level nodes are expected here, see the open
      // criterion at the end of the loop
      assert(current->u.d.bitflags & (1 << BITFLAG_TOPLEVEL));
	
      if(current->Ti_current != All.Ti_Current)
	force_drift_node(no, All.Ti_Current);
	
      no = current->u.d.sibling;	/* in case the node can be discarded */

      // first test each coordinate separately against the search
      // radius plus half the node size
      double dist = mindist + 0.5 * current->len;;
      double dx = NGB_PERIODIC_LONG_X(current->center[0] - p[0]);
      if(dx > dist)
	continue;
      double dy = NGB_PERIODIC_LONG_Y(current->center[1] - p[1]);
      if(dy > dist)
	continue;
      double dz = NGB_PERIODIC_LONG_Z(current->center[2] - p[2]);
      if(dz > dist)
	continue;
      /* now test against the minimal sphere enclosing everything */
      dist += FACT1 * current->len;
      if(dx * dx + dy * dy + dz * dz > dist * dist)
	continue;
	
      /* We don't open the node if:
          * next is a particle
	  * next is not toplevel,
	  * UNLESS next is a pseudonode
      */
      int next = current->u.d.nextnode;
      bool next_is_part = (next < All.MaxPart);
      int next_is_pseudo = (next >= All.MaxPart+MaxNodes);
      
      if( !next_is_pseudo &&
	  (next_is_part || 
	   !(Nodes[next].u.d.bitflags & (1 << BITFLAG_TOPLEVEL))) )
	continue;

      // open the node: descend to its first daughter instead of
      // going on to the sibling
      no = next;
    }
  }
  return any_overlapping;
}


/** Finds the cell containing the point p and returns it as a pair of
    (index, task). If accept_outside is false and p is outside of the
    Arepo box, (-1,-1) is returned. The search is modeled on
    find_nearest_meshpoint_global, but since it only does one position
    at a time and needs to be nonblocking it's a bit different: First
    the nearest local mesh point is found. If the search region
    overlaps with other tasks, those tasks are asked for their nearest
    point and the closest of all the answers is used. The thread
    parameter is used to distinguish location requests that require
    MPI messages, as it determines the message tags.
 */
std::pair<int,int> arepo::locate(const mcrx::vec3d& p, int thread,
				 bool accept_outside)
{
  namespace mpi = boost::mpi;

  if(!accept_outside && !in_arepo_box(p))
    return make_pair(-1,-1);
  
  // position in Arepo units
  vec3d pos = p*(1./arepo::lcon);

  MyFloat mindist=-1;
  int search_done;
  int idx = find_nearest_meshpoint_local(&pos[0], -1, &mindist,
					 1, &search_done);

  // If the local search was not conclusive, the search radius
  // overlaps with other tasks. We must then walk the tree again to
  // see if the final mindist actually overlaps with any tasks and, if
  // so, which they are. Unfortunately we can't use Arepo's
  // treefind_overlapping_domains because it is not thread safe, so we
  // use our own C++ version instead. Note that even in this case we
  // don't necessarily *have* to talk to anyone else, because the
  // returned mask may be empty.
  vector<bool> taskmask;
  const bool must_communicate =
    !search_done && treefind_overlapping_domains(pos, mindist, taskmask);

  if(!must_communicate) {
    // we know the point is on this task. no communication necessary
    DEBUG(1,printf("Thread %d-%d found location without comms\n",ThisTask, thread));
    return make_pair(idx, ThisTask);
  }

  // we are indeed overlapping with other tasks, so we must talk.

  mpi::communicator world;
  tbb::tick_count now(tbb::tick_count::now());

  // make a list of the tasks we need to talk to, in rank order
  vector<int> targets;
  const int n_tasks = static_cast<int>(taskmask.size());
  for(int rank=0; rank<n_tasks; ++rank) {
    if(taskmask[rank])
      targets.push_back(rank);
  }
  const int n_targets = static_cast<int>(targets.size());
  assert(n_targets>0);

  // send requests to the overlapping tasks
  const location_request request_data(pos, mindist);
  vector<mpi::request> send_reqs;
  send_reqs.reserve(n_targets);
  for(int k=0; k<n_targets; ++k) {
    DEBUG(1,printf("Worker %d-%d posting location request to task %d\n",ThisTask, thread, targets[k]);cout.flush(););
    send_reqs.push_back(world.isend(targets[k],
				    mcrx::request_thread2tag(thread), 
				    mpi::get_content(request_data)));
  }

  // wait on the outgoing requests
  mpi::wait_all(send_reqs.begin(), send_reqs.end());

  // and then immediately post receives for the responses, one per
  // task that we sent a request to
  vector<location_response> responses(n_targets);
  vector<mpi::request> recv_reqs;
  recv_reqs.reserve(n_targets);
  for(int k=0; k<n_targets; ++k) {
    recv_reqs.push_back(world.irecv(targets[k],
				    mcrx::response_thread2tag(thread), 
				    mpi::get_content(responses[k])));
  }

  // wait for them to complete
  vector<mpi::status> stats(n_targets);
  mpi::wait_all(recv_reqs.begin(), recv_reqs.end(), stats.begin());

  DEBUG(1,printf("Worker %d-%d location request took %7.2gms\n",ThisTask, thread, (tbb::tick_count::now()-now).seconds()*1000););

  // and then process the responses. The location is owned by
  // whoever has the mesh point with the smallest distance.
  int task=ThisTask;
  for(int k=0; k<n_targets; ++k) {
    if(!(responses[k].mindist<mindist))
      continue;

    mindist=responses[k].mindist;
    idx=responses[k].idx;
    task=stats[k].source();
    assert(stats[k].source()==responses[k].task);
  }

  DEBUG(1,printf("Worker %d-%d: location owned by task %d\n",ThisTask, thread, task);cout.flush(););

  return make_pair(idx, task);
}


/** Determines which of the positions are on the domain of this
    task. Returns a vector with one entry per position, which is true
    if the top node containing the position belongs to this
    task. Positions outside of the Arepo box are truncated to the box
    before the test. The positions themselves are not modified. */
vector<bool> arepo::position_ownership(vector<vec3d>& positions)
{
  // pick out the positions that are on this domain. this is easy
  // enough, and since we in general will need to do a search anyway
  // because the particles are fuzzy, there doesn't seem to be much to
  // gain by doing an actual location.
  vector<bool> own;

  for(vector<vec3d>::const_iterator it=positions.begin();
      it!=positions.end(); ++it) {
    vec3d clamped(*it);

    // to ensure that we get a valid result for positions outside the
    // box, we truncate the position to the box
    if(!in_arepo_box(clamped))
      clamped = mcrx::truncate_to_box(clamped, arepomin, arepomax);

    // convert to arepo units
    clamped *= (1./arepo::lcon);

    // The peano-hilbert key of the position gives us the topnode
    // containing it, and from that we get the task.
    const peanokey key = position_to_peanokey(&clamped[0]);
    const int topnode = peanokey_to_topnode(key);

    own.push_back(DomainTask[topnode]==ThisTask);
  }

  return own;
}
  

/** Constructor calculates the factors that operator() multiplies the
    Arepo pressure (pcon) and density (dcon) with. The factors are
    derived from arepo::lcon, mcon and tcon as they are when the
    object is constructed, so the object must be created after
    load_snapshot() has set those. */
arepo::arepo_pressure_functor::arepo_pressure_functor() :
  // conversion factor for P -> P/k is to convert pressure to cgs and
  // then divide by Boltzmann's constant in erg/K. Also need to take h
  // into account, and pressure scales like h^2.
  // NOTE(review): pcon is built from mcon, lcon and tcon without a
  // units::convert between arepo_units and cgs, while dcon has one --
  // confirm that the result is in cgs for all choices of arepo_units.
  pcon(arepo::mcon/(arepo::lcon*arepo::tcon*arepo::tcon*
		    constants::k*units::convert("J/K","erg/K"))),
  // density factor: mass/length^3 from the conversion factors, times
  // a conversion between g/cm^3 and the arepo_units density unit
  dcon(arepo::mcon/(arepo::lcon*arepo::lcon*arepo::lcon)*
       units::convert("g/cm^3",arepo_units.get("mass")+"/"+
		      arepo_units.get("length")+"^3"))
{}


/** Looks up the grid cell at the specified position and returns its
    P/k and density (in cgs units!) and metallicity, as a pair of
    (P/k, (density, metallicity)). If the position is outside of the
    grid, tiny values are returned for P/k and density, and zero for
    the metallicity. This only works when running on one task. */
pair<T_float,pair<T_float, T_float> > 
arepo::arepo_pressure_functor::operator()(const mcrx::vec3d& pos) const 
{
  assert(NTask==1);

  const int cell = locate(pos,0).first;

  if(cell<0) {
    // particle is outside the grid, we have no pressure info
    const T_float floor_value = blitz::tiny(T_float());
    return make_pair(floor_value, make_pair(floor_value, 0.));
  }

  return make_pair(SphP[cell].Pressure*pcon, 
		   make_pair(SphP[cell].Density*dcon, 
			     SphP[cell].Metallicity));
}


/** Finds the next Voronoi cell a ray will enter. Returns a pair
    containing the index and task of the Voronoi cell bordering
    sphp_idx where the ray r intersects the cell face. If the ray
    would enter a cell that is on another task and whose mesh point is
    outside of the Arepo box, (-1, -1) is returned instead. In both
    cases, len is set to the length to the intersection point, in
    Sunrise length units.

    This function no longer uses the special Sunrise data structure
    for cell edges, instead using the Arepo data. This is because the
    sunrise data is hard-coded to be single task.

    \param r the ray, with position in Sunrise units.
    \param sphp_idx the cell the ray is currently in.
    \param previous passed on to find_next_voronoi_cell.
    \param len on return, the length to the intersection point.
 */
pair<int, int> arepo::find_next_cell(const mcrx::ray_base& r, 
				     int sphp_idx, int previous, 
				     mcrx::T_float& len)
{
  point * const DP = T->DP;

  // get position in Arepo units.
  vec3d pp(r.position()*(1./lcon));

  // get the edge to the next cell with the arepo function
  const int edge = 
    find_next_voronoi_cell(T, sphp_idx, &(pp[0]), 
			   const_cast<double*>(&(r.direction()[0])),
			   previous, &len);

  // the edge is used to index DC below, so check it before that
  assert(edge>=0);

  // transform length to sunrise units. This is done before any of
  // the returns, so that the caller gets len in the same units
  // regardless of which way we exit.
  len *= lcon;

  // because the voronoi cells wrap across the boundary, we will also
  // encounter a different task just to traverse a fraction of a
  // cell. This is extremely wasteful, so if we are entering another
  // task and the mesh point we enter is across the boundary, we just
  // ignore it and pretend we are exiting. This introduces a small
  // error in the column densities, but if there is significant
  // density in those cells things are screwed up anyway.
  if(DC[edge].task != ThisTask) {
    // the voronoi ghost point will be outside the box
    const int dp_idx = DC[edge].dp_index;
    const vec3d other_pos(DP[dp_idx].x,DP[dp_idx].y,DP[dp_idx].z);
    if(!in_arepo_box(other_pos*lcon))
      return make_pair(-1, -1);
  }

  return make_pair(DC[edge].index, DC[edge].task);
}


/** Returns the dp index of the point out of the specified point and
    all its voronoi neighbors that is closest to the point p. The
    function works ok with ghost points. */
/*
int arepo::find_closest_neighbor_in_cell(const mcrx::vec3d& p, int dp_idx)
{
  const point* const DP = T->DP;

  vec3d celldist(p-get_point(dp_idx));

  double dist2min = dot(celldist, celldist);

  const int start_edge = midpoint_idx[dp_idx].first;
  const int n_edges = midpoint_idx[dp_idx].second;

  DEBUG(2,cout<<"Testing for closest neighbor of cell " << dp_idx << endl;);
  
  for(int i=0; i < n_edges; ++i) {
    const int dp_neighbor = opposite_points[start_edge+i];

    DEBUG(3,cout<<"\tedge " << i << ", neighbor " << dp_neighbor << endl;);

    celldist = p-get_point(dp_neighbor);

    const double dist2 = dot(celldist, celldist);
    if(dist2 < dist2min) {
      dist2min = dist2;
      dp_idx = dp_neighbor;
    }
  }

  DEBUG(2,cout << "\tFound " << dp_idx << endl;);
  return dp_idx;
}
*/

/** Determines if a point is inside a cell by checking the
    intersections with the face planes. This is closely modeled on
    find_next_voronoi_cell in Arepo. The point should be in Arepo
    units.

    \param sph_idx the cell to test, used to index the P and SphP
    arrays.
    \param pos the position to test, in Arepo units.
    \return true if pos is on the inside of all the faces of the
    cell. */
bool arepo::point_inside(int sph_idx, const vec3d& pos)
{
  point *DP = T->DP;
  myassert(DP[sph_idx].index >= 0);
  myassert(DP[sph_idx].index < T->Ndp);

  vec3d cell_p(P[sph_idx].Pos[0], P[sph_idx].Pos[1], P[sph_idx].Pos[2]);

  // if mesh point is across the boundary, wrap it. Note that
  // periodic_wrap_point returns the wrapped point and leaves its
  // arguments alone, so the result must be assigned.
  cell_p = periodic_wrap_point(cell_p, pos);

  int edge = SphP[sph_idx].first_connection;
  const int last_edge = SphP[sph_idx].last_connection;

  while(true) {
    const int neighbor = DC[edge].dp_index;
    myassert((DC[edge].task!=ThisTask) || (DC[edge].index!=sph_idx));

    vec3d nb_p (DP[neighbor].x, DP[neighbor].y, DP[neighbor].z);
    // if neighbor is across the boundary, wrap it
    nb_p = periodic_wrap_point(nb_p, pos);

    // m is the edge midpoint, which is a point on the face plane
    const vec3d m(0.5*(nb_p + cell_p));

    // c is the vector from pos to the point on the face
    const vec3d c(m-pos);

    // q is the edge vector to the neighboring cell, which is a
    // normal vector to the plane
    const vec3d q(nb_p - cell_p);

    // The point is on the inside of this face if c.q>0.
    const T_float cdotq = dot(c,q);

    if(cdotq < 0)
      // outside of one face is enough to be outside of the cell
      return false;

    if(edge == last_edge)
      break;

    edge = DC[edge].next;
  }

  // we were on the inside of all faces
  return true;
}


/** Solve for the fraction of the distance to get a specified fraction
    of the total column density along the entire segment through a
    single Voronoi cell. 

    The fraction depends on the length of the total segment L and the
    quantity alpha = (density at starting point)/((density gradient
    . direction) * L).

    By integrating the column density along the ray, we get that the
    fractional distance is given by the solution to the quadratic
    equation

    fl = -alpha +- sqrt( alpha^2 + (2*alpha+1)*fn )

    When the gradient along the ray is vanishingly small (|alpha| is
    very large, or alpha is not a number because both the density and
    the gradient are zero), the density is constant along the segment
    and the solution is just fl = fn.

    \param idx defines the cell, as an index into the SphP array.
    \param fn the column density fraction, in [0,1].
    \param pos the position of the start of the line segment, in
    Sunrise units.
    \param dir the direction of the line segment.
    \param L the length of the segment, in Sunrise units. Must be
    positive.
    \return the fraction of L, in [0,1], at which the column density
    fraction fn is reached.
*/
mcrx::T_float
arepo::column_to_distance_fraction(int idx, mcrx::T_float fn, 
				   const mcrx::vec3d& pos, 
				   const mcrx::vec3d& dir, mcrx::T_float L) 
{
  using arepo::lcon;
  assert(L>0);
  assert(fn>=0);
  assert(fn<=1);
  assert_contains(idx, pos);
  assert_contains(idx, pos+dir*L);

  const vec3d cell_center(SphP[idx].Center[0],
			  SphP[idx].Center[1],
			  SphP[idx].Center[2]);

  const vec3d delta_center = 
    get_periodic_delta((1./lcon)*pos, cell_center);

  const vec3d arepo_gradient (SphP[idx].Grad.drho[0],
			      SphP[idx].Grad.drho[1],
			      SphP[idx].Grad.drho[2]);

  // we don't actually need to know the sunrise densities for this
  // calculation, because the solution only depends on the relative
  // gradient. (since we assume no gradient in metallicity.)
  const T_float rho_start(SphP[idx].Density + 
			  dot(arepo_gradient, delta_center));

  // note coordinates here are in AREPO units

  const T_float alpha = rho_start/(dot(arepo_gradient, dir)*(L/lcon));

  // in principle the density should never be zero if we are
  // scattering, but there are problems with the gradients that cause
  // the density to undershoot
  assert(rho_start>=0);
#ifndef NDEBUG
  if(1.+1./alpha<0) {
    // this means density undershoots zero at the end point of the segment
    cerr << "Error: Density undershoot at endpoint " << vec3d(pos+L*dir)<< " in cell " << idx << ": " << cell_density(idx, (pos+dir*L)/lcon) << endl;
    const vec3d pexit(pos+dir*L);
    const int nearest=locate_brute(pexit);
    const T_float dist = mcrx::mag(pexit-get_point(idx));
    const T_float dist2 = mcrx::mag(pexit-get_point(nearest));

    cout << "Nearest cell is " << nearest << " dist " << dist2 << " but we are in " << idx << " dist " << dist<< endl;
    assert(0);
  }
#endif

  // The gradient is vanishingly small if alpha is very large,
  // regardless of its sign. The test is written such that it is also
  // true if alpha is infinite or NaN, which happens if the gradient
  // along the ray is exactly zero. Moderately negative alpha (density
  // decreasing along the ray) must go on to the quadratic below.
  if( !(alpha<=1e10 && alpha>=-1e10) ) {
    // then we recover the normal solution
    return fn;
  }

  // solve quadratic
  const T_float f1 = -alpha;
  const T_float f2 = sqrt(alpha*alpha + (2*alpha+1)*fn);
  T_float fl=-1;

  if(f1<=0)
    // answer must be high root
    fl=f1+f2;
  else if(f1>=1)
    // must be low root
    fl=f1-f2;
  else {
    // both roots can potentially be valid
    if(f2==0) {
      // double root
      fl=f1;
    }
    else {
      // now we have to check which root it is
      T_float fa=f1+f2;
      T_float fb=f1-f2;

      bool va((fa>=0)&&(fa<=1));
      bool vb((fb>=0)&&(fb<=1));
      // make sure only one solution is valid
      assert(va!=vb);

      // pick valid solution
      fl = va ? fa : fb;
    }
  }
  
  assert((fl>=0)&&(fl<=1));
  return fl;
}


#ifndef NDEBUG
/** Debugging check that the position pos is contained in the cell
    idx. The position is divided by lcon and the actual test is then
    done by the Arepo function with the same name. This function only
    exists when NDEBUG is not defined. */
void arepo::assert_contains(int idx, mcrx::vec3d pos)
{
  // position in Arepo units
  const vec3d apos((1./lcon)*pos);

  // the Arepo function takes a non-const pointer, hence the cast
  ::assert_contains(T, idx, const_cast<double*>(&(apos[0])));
  // What follows is the old implementation, which used the Sunrise
  // midpoint data structures. It is disabled.
  /*
  const T_float INSIDE_EPS=5e-4;

		    
  vec3d celldist(pos - get_point(idx));

  const double dist2p = dot(celldist, celldist);
  if(dist2p == 0.) 
    // if we are sitting right on top of the mesh point, we'll get NaN below
    return;
  double dist2min = dist2p;

  const int start_edge = midpoint_idx[dp_idx].first;
  const int n_edges = midpoint_idx[dp_idx].second;
  int c=-1;
  for(int i=0; i < n_edges; ++i) {
    const int dp_neighbor = opposite_points[start_edge+i];

    celldist = pos - get_point(dp_neighbor);

    const double dist2 = dot(celldist, celldist);
    if(dist2 < dist2min) {
      dist2min = dist2;
      c = dp_neighbor;
    }
  }

  const bool pass = dist2p/dist2min < 1+INSIDE_EPS;

  if(!pass) cout << "Failed assert_contains for cell " << dp_idx << ", position " << pos << ", nearest cell " << c << " by " << dist2p/dist2min-1 << endl;

  assert(pass);
  */
};
#endif


/** Calculates the density (in Sunrise units) at the point pos (in
    Arepo units) in the cell idx, which is an index into the SphP
    array. The density is extrapolated from the cell center using the
    density gradient of the cell. The point is thus assumed to be
    inside the cell, but this is not checked here. Note the mixed
    units, this is because all points are stored in Arepo units. */
mcrx::T_float arepo::cell_density(int idx, const mcrx::vec3d& pos)
{
  const vec3d gradient(SphP[idx].Grad.drho[0],
		       SphP[idx].Grad.drho[1],
		       SphP[idx].Grad.drho[2]);

  const vec3d center(SphP[idx].Center[0],
		     SphP[idx].Center[1],
		     SphP[idx].Center[2]);

  // offset from the cell center, taking periodicity into account
  const vec3d offset = get_periodic_delta(pos, center);

  // interpolate density and convert units
  const T_float rho = mcon/(lcon*lcon*lcon)*
    (SphP[idx].Density + dot(offset, gradient)) ;

  // catch NaN
  assert(rho==rho);

  return rho;
}


/** Loops over the circumsphere centers (ie Voronoi face vertices) and
    applies a further limiting of the density gradient to ensure that
    the density is non-negative at all points in the cell. (The Arepo
    gradient limiter only ensures it is non-negative at the face
    centers.) This is all done in Arepo units. The density gradients
    in the SphP array are modified in place.
*/
void arepo::ensure_positive_densities()
{
  tetra* DT = arepo::T->DT;
  point* DP = arepo::T->DP;
  tetra_center* DTC = arepo::T->DTC;

  std::cout << "Ensuring positive definite density field." << std::endl;

  // Start by looping over all tetras
  for(int t=0; t<arepo::T->Ndt; ++t) {
    if( DT[t].t[0]<0 || 
	DT[t].p[0] == DPinfinity || DT[t].p[1] == DPinfinity || 
	DT[t].p[2] == DPinfinity || DT[t].p[3] == DPinfinity )
      // tetras with t[0]<0 or that contain an infinity point do not
      // have computed circumspheres. what does this mean?  It's fine,
      // either they are deleted and the infinity points are well
      // outside our box.
      continue;

    // circumsphere center
    const vec3d dtc(DTC[t].cx, DTC[t].cy, DTC[t].cz);
    
    // For each tetrahedron, we loop over the four vertices. 
    for(int p=0; p<4; ++p) {
      // find the voronoi cell connected to this vertex. Here, we need
      // only concern ourselves with primary mesh points. They may be
      // duplicated as ghosts, but to get all dtc's of a primary mesh
      // point, we need only process tetras that are connected to the
      // primary point itself.
      const int dp_idx =DT[t].p[p];

      // make sure the index is in range *before* we use it to look
      // anything up in the DP and SphP arrays
      if( (dp_idx<0) || (dp_idx>=N_gas) )
	continue;

      if( (DP[dp_idx].task!=ThisTask) || !(SphP[dp_idx].Density>0) )
	continue;

      // valid cell, calculate the density at the vertex
      const T_float rho = cell_density(dp_idx, dtc);
	  
      if (rho<0) {
	// calculate the max allowed gradient (actually we
	// overcompensate a bit to lessen the risk of roundoff
	// leading to a negative density anyway)
	const T_float rho0 = SphP[dp_idx].Density*mcon/(lcon*lcon*lcon);
	const T_float alpha = rho0/(rho0-rho)*(1.-1e-6);
	  
	DEBUG(3,cout << "Applying additional slope limit by " << alpha << " for cell " << dp_idx << " at point " << dtc << endl;);
	  
	assert(alpha<1);
	assert(alpha>=0);
	// apply limit to gradient
	SphP[dp_idx].Grad.drho[0] *= alpha;
	SphP[dp_idx].Grad.drho[1] *= alpha;
	SphP[dp_idx].Grad.drho[2] *= alpha;
      }
    }
  }
}
  
